{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# RGB Anomaly Detection\n",
    "\n",
    "---\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/roboflow/notebooks/blob/main/notebooks/rgb_anomalyDetection.ipynb)\n",
    "\n",
    "In this cookbook, we identify color / RGB anomalies for segmented items. Capture a base image to extract your ground truth RGB with Roboflow and compare to neew data collected. In this scenario, we are assessing variations in logo color.\n",
    "\n",
    "Click the Open in Colab button to run the cookbook on Google Colab.\n",
    "\n",
    "**Let's begin!**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install required packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install sklearn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import numpy as np\n",
    "import time\n",
    "import base64\n",
    "import requests\n",
    "import os, glob\n",
    "from sklearn.cluster import KMeans"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extract target RGB color from polygon and run Kmeans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_polygon_annotation(annotation_data, image_shape):\n",
    "    width, height = image_shape[1], image_shape[0]\n",
    "    return [(int(data['x']), int(data['y'])) for data in annotation_data]\n",
    "\n",
    "def extract_polygon_area(image_path, polygon_points):\n",
    "    image = cv2.imread(image_path)\n",
    "    mask = np.zeros(image.shape[:2], dtype=np.uint8)\n",
    "    cv2.drawContours(mask, [np.array(polygon_points)], -1, (255, 255, 255), -1)\n",
    "    return cv2.bitwise_and(image, image, mask=mask)\n",
    "\n",
    "def compute_average_color(image):\n",
    "    mask = np.all(image != [0, 0, 0], axis=2)\n",
    "    avg_color = np.mean(image[mask], axis=0)\n",
    "    return avg_color\n",
    "\n",
    "def color_difference(color1, color2):\n",
    "    return np.linalg.norm(np.array(color1) - np.array(color2))\n",
    "\n",
    "def count_color_matches(dominant_colors, target_colors, threshold):\n",
    "    matches_count = {tuple(target): 0 for target in target_colors}\n",
    "    matched_colors = {tuple(target): [] for target in target_colors}\n",
    "    \n",
    "    for color in dominant_colors:\n",
    "        for target in target_colors:\n",
    "            difference = color_difference(color, target)\n",
    "            \n",
    "            if difference < threshold:\n",
    "                matches_count[tuple(target)] += 1\n",
    "                matched_colors[tuple(target)].append(color)\n",
    "                \n",
    "    return matches_count, matched_colors\n",
    "\n",
    "def get_dominant_colors(image, k=5):\n",
    "    image = image.reshape((image.shape[0] * image.shape[1], 3))\n",
    "    image = image[np.any(image != [0, 0, 0], axis=1)]\n",
    "    kmeans = KMeans(n_clusters=k, n_init='auto')\n",
    "    kmeans.fit(image)\n",
    "    dominant_colors = kmeans.cluster_centers_\n",
    "    return dominant_colors\n",
    "\n",
    "def extract_target_colors(target_image_path,inference_server_address, project_id, version_number):\n",
    "    target_image = cv2.imread(target_image_path)\n",
    "    with open(target_image_path, \"rb\") as f:\n",
    "        im_bytes = f.read()        \n",
    "    im_b64 = base64.b64encode(im_bytes).decode(\"utf8\")\n",
    "\n",
    "    headers = {\n",
    "        'Content-Type': 'application/json',\n",
    "        'Accept': 'application/json'\n",
    "    }\n",
    "\n",
    "    params = {\n",
    "        'api_key': 'FFgkmScNUBERP9t3PJvV',\n",
    "    }\n",
    "\n",
    "    response = requests.post(inference_server_address + project_id + str(version_number), params=params, headers=headers, data=im_b64)\n",
    "    data = response.json()\n",
    "\n",
    "    for predictions in data['predictions']:\n",
    "        Pred_points = predictions['points']\n",
    "        target_image = cv2.imread(target_image_path)\n",
    "        polygon_points = parse_polygon_annotation(Pred_points, target_image.shape)\n",
    "        polygon_image = extract_polygon_area(target_image_path, polygon_points)\n",
    "        target_dominant_colors = get_dominant_colors(polygon_image)\n",
    "    \n",
    "    return target_dominant_colors\n",
    "\n",
    "def match_images_with_target_colors(target_dominant_colors, images_folder, inference_server_address, project_id, version_number, color_threshold):\n",
    "    global prediction_counter, image_counter\n",
    "    total_matches = 0\n",
    "    matched_filepaths = []\n",
    "\n",
    "    extention_images = \".jpg\"\n",
    "    get_images = sorted(glob.glob(images_folder + '*' + extention_images))\n",
    "\n",
    "    for images in get_images:\n",
    "        t0 = time.time()\n",
    "        print(\"File path: \" + images)\n",
    "        img = cv2.imread(images)\n",
    "        with open(images, \"rb\") as f:\n",
    "            im_bytes = f.read()        \n",
    "        im_b64 = base64.b64encode(im_bytes).decode(\"utf8\")\n",
    "        headers = {\n",
    "            'Content-Type': 'application/json',\n",
    "            'Accept': 'application/json'\n",
    "        }\n",
    "\n",
    "        params = {\n",
    "            'api_key': '',\n",
    "        }\n",
    "        \n",
    "        response = requests.post(inference_server_address + project_id + str(version_number), params=params, headers=headers, data=im_b64)\n",
    "        data = response.json()\n",
    "\n",
    "        for predictions in data['predictions']:\n",
    "            prediction_counter += 1\n",
    "            image_counter += 1\n",
    "            Pred_points = predictions['points']\n",
    "            image = cv2.imread(images)\n",
    "            polygon_points = parse_polygon_annotation(Pred_points, image.shape)\n",
    "            polygon_image = extract_polygon_area(images, polygon_points)\n",
    "            dominant_colors = get_dominant_colors(polygon_image)\n",
    "            matches, matched_colors_list = count_color_matches(dominant_colors, target_dominant_colors, color_threshold)\n",
    "        \n",
    "        all_matched = all(value > 0 for value in matches.values())\n",
    "        \n",
    "        if all_matched:\n",
    "            matched_filepaths.append(images)\n",
    "            total_matches += 1\n",
    "\n",
    "    print(f\"\\nTotal images where all target colors matched: {total_matches}\")\n",
    "    print(f\"\\nMatched images where all target colors matched: {matched_filepaths}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![T-shirt](https://storage.googleapis.com/com-roboflow-marketing/tradesmen2.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run main function to compare base color logo with target colors and run anomaly detection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main():\n",
    "    target_image_path = \"TARGET_IMAGE_PATH\"\n",
    "    inference_server_address = \"http://detect.roboflow.com/\"\n",
    "    version_number = 1\n",
    "    project_id = \"PROJECT_ID\"\n",
    "    images_folder = \"IMAGE_FOLDER_PATH\"\n",
    "    # grab all the .jpg files\n",
    "    extention_images = \".jpg\"\n",
    "    get_images = sorted(glob.glob(images_folder + '*' + extention_images))\n",
    "    MAX_COLOR_DIFFERENCE = 3 * 256 # DO NOT EDIT\n",
    "    TARGET_COLOR_PERCENT_THRESHOLD= 0.08 # Value must be between 0 - 1 - DO EDIT\n",
    "    color_threshold = int(MAX_COLOR_DIFFERENCE * TARGET_COLOR_PERCENT_THRESHOLD)\n",
    "\n",
    "\n",
    "    target_dominant_colors = extract_target_colors(target_image_path,inference_server_address, project_id, version_number)\n",
    "    match_images_with_target_colors(target_dominant_colors, images_folder, inference_server_address, project_id, version_number, color_threshold)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Logo](https://storage.googleapis.com/com-roboflow-marketing/target_polygon_result1.jpg)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
